# This file calls find_package(CUDAToolkit), whose find module ships with
# CMake 3.17 and later. Declaring that as the minimum makes older CMake
# releases stop here with a clear version error instead of failing later
# with a "module not found" message.
cmake_minimum_required(VERSION 3.17)

# Both C++ and CUDA are enabled; enabling CUDA makes CMake locate and test a
# CUDA compiler at this point.
project(flash-attention LANGUAGES CXX CUDA)

# --- Configure-time probing of the GPU driver tool and the CUDA compiler ---
#
# Results left for the rest of the file:
#   GPU_dev_info  - path to nvidia-smi (cache entry created by find_program)
#   NVCC          - path to nvcc       (cache entry created by find_program)
#   GPU_H100      - "H100" when that text occurs in the nvidia-smi output,
#                   empty otherwise
#   NVCC_VERSION  - "<major>.<minor>" parsed from `nvcc --version`

# The presence of nvidia-smi is used as the indicator that a GPU driver is
# installed; configuration aborts without it.
find_program(GPU_dev_info nvidia-smi)
if(NOT GPU_dev_info)
    message(FATAL_ERROR "GPU driver not found.")
endif()

find_program(NVCC nvcc)
if(NOT NVCC)
    message(FATAL_ERROR "NVCC not found. Please make sure CUDA Toolkit is installed.")
endif()

# Despite its name, GPU_dev_version receives the complete nvidia-smi output.
execute_process(
    COMMAND "${GPU_dev_info}"
    OUTPUT_VARIABLE GPU_dev_version
)
# The expansion is quoted so that empty output (or output containing ';')
# still reaches string() as exactly one input argument.
string(REGEX MATCH "H100" GPU_H100 "${GPU_dev_version}")

execute_process(
    COMMAND "${NVCC}" --version
    OUTPUT_VARIABLE NVCC_VERSION_OUTPUT
)
# Prefer the number that follows the word "release" in the nvcc banner, so
# that an unrelated "<digits>.<digits>" earlier in the text cannot be picked
# up. If that wording is absent, fall back to the first "<digits>.<digits>"
# anywhere in the output.
if("${NVCC_VERSION_OUTPUT}" MATCHES "release ([0-9]+\\.[0-9]+)")
    set(NVCC_VERSION "${CMAKE_MATCH_1}")
else()
    string(REGEX MATCH "[0-9]+\\.[0-9]+" NVCC_VERSION "${NVCC_VERSION_OUTPUT}")
endif()

if(NOT NVCC_VERSION)
    message(FATAL_ERROR "Failed to determine NVCC version.")
endif()

# Choose the architecture to compile for:
#   * "90a" plus the HOPPER preprocessor definition when an H100 was detected
#     and nvcc is version 12 or newer;
#   * "80" in every other case.
# VERSION_GREATER_EQUAL compares NVCC_VERSION component-wise as a version
# string rather than as a floating-point number.
if(GPU_H100 AND NVCC_VERSION VERSION_GREATER_EQUAL 12)
    set(compute_capability "90a")
    # Directory-scoped, so every target created after this point sees HOPPER.
    add_compile_definitions(HOPPER)
else()
    set(compute_capability "80")
endif()

message(STATUS "compute_capability: ${compute_capability}")

# Extra nvcc options applied to every CUDA compilation in this directory.
# The \" sequences place escaped quotes in the flag string so the
# comma-separated code= list reaches nvcc as a single argument.
# NOTE(review): the flash_attn target also sets CUDA_ARCHITECTURES to the same
# value, so on CMake versions that honour that property the architecture is
# requested twice - confirm whether both mechanisms are intended.
string(APPEND CMAKE_CUDA_FLAGS
    " --expt-relaxed-constexpr"
    " --use_fast_math"
    " -t 8"
    " -gencode=arch=compute_${compute_capability},code=\\\"sm_${compute_capability},compute_${compute_capability}\\\""
)

# ################ CUDA ################
# Locate the CUDA Toolkit; on success the CUDAToolkit_* result variables used
# below are defined.
find_package(CUDAToolkit REQUIRED)

# Header search paths for every target defined after this point: the cutlass
# headers under ./cutlass/include, the project's own headers, and the
# toolkit headers.
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/cutlass/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/include")
include_directories(${CUDAToolkit_INCLUDE_DIRS})

# Library search path reported by FindCUDAToolkit.
link_directories(${CUDAToolkit_LIBRARY_DIR})

# CUDA_TOOLKIT_ROOT_DIR belongs to the legacy FindCUDA module and is not set
# by find_package(CUDAToolkit). Honour it only when the user supplied it
# (e.g. with -DCUDA_TOOLKIT_ROOT_DIR=...); expanding it while empty would add
# the bogus paths "/include" and "/lib64".
if(CUDA_TOOLKIT_ROOT_DIR)
    include_directories("${CUDA_TOOLKIT_ROOT_DIR}/include")
    link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64")
endif()


# flash_attn shared library.
# The sm80 forward kernels (regular and split variants) are always part of the
# library; the sm90 kernels under hopper/ are appended only when the selected
# compute capability is "90a".
add_library(flash_attn SHARED
    src/flash.cu

    src/flash_fwd_hdim128_fp16_sm80.cu
    src/flash_fwd_hdim160_fp16_sm80.cu
    src/flash_fwd_hdim192_fp16_sm80.cu
    src/flash_fwd_hdim224_fp16_sm80.cu
    src/flash_fwd_hdim256_fp16_sm80.cu
    src/flash_fwd_hdim32_fp16_sm80.cu
    src/flash_fwd_hdim64_fp16_sm80.cu
    src/flash_fwd_hdim96_fp16_sm80.cu
    src/flash_fwd_split_hdim128_fp16_sm80.cu
    src/flash_fwd_split_hdim160_fp16_sm80.cu
    src/flash_fwd_split_hdim192_fp16_sm80.cu
    src/flash_fwd_split_hdim224_fp16_sm80.cu
    src/flash_fwd_split_hdim256_fp16_sm80.cu
    src/flash_fwd_split_hdim32_fp16_sm80.cu
    src/flash_fwd_split_hdim64_fp16_sm80.cu
    src/flash_fwd_split_hdim96_fp16_sm80.cu
)

if("${compute_capability}" STREQUAL "90a")
    # Hopper-only kernels: fp16 and e4m3 variants for head dims 64/128/256.
    target_sources(flash_attn PRIVATE
        hopper/flash_fwd_hdim64_fp16_sm90.cu
        hopper/flash_fwd_hdim128_fp16_sm90.cu
        hopper/flash_fwd_hdim256_fp16_sm90.cu
        hopper/flash_fwd_hdim64_e4m3_sm90.cu
        hopper/flash_fwd_hdim128_e4m3_sm90.cu
        hopper/flash_fwd_hdim256_e4m3_sm90.cu
    )
endif()

# Record the chosen architecture on the target itself.
set_property(TARGET flash_attn PROPERTY CUDA_ARCHITECTURES "${compute_capability}")
